/***************************************************************************
 *
 *   Copyright (C) 2023 by Willem van Straten
 *   Licensed under the Academic Free License version 2.1
 *
 ***************************************************************************/

#ifndef __dsp_BlockConstraints_h
#define __dsp_BlockConstraints_h

#include <string>
#include <iostream>

namespace dsp
{
//! Configures the minimum_samples and input_overlap for LoadToFold and LoadToFITS
//!
//! Collects the minimum number of samples, and the number of samples lost at
//! block edges, from the pipeline's optional filterbank and convolution
//! operations and passes them to pipeline->set_block_size.
//!
//! When samples are lost (only possible when config->input_buffering is false),
//! the pipeline first proposes a block size. That size is then trimmed so that
//! the non-overlapping part of each block is a whole number of convolution
//! strides. If the pipeline has both a filterbank and a convolution, the block
//! size is further constrained to be a multiple of the filterbank resolution.
//!
//! \param pipeline      provides filterbank, convolution, get_source() and set_block_size()
//! \param config        provides filterbank settings, coherent_dedispersion,
//!                      coherent_derotation and input_buffering
//! \param app           prefix used in diagnostic messages and errors
//! \param report_vitals when true, print the sample requirements to std::cerr
//!
//! \throw Error (InvalidState) if the convolution would not advance (its minimum
//!        number of samples does not exceed the overlap), or if no block size
//!        suits both the Filterbank and the Convolution

template<class Pipeline, class Config>
void convolution_block_constraints(Pipeline* pipeline, Config* config, const std::string& app, bool report_vitals)
{
  unsigned minimum_samples = 0;
  unsigned input_overlap = 0;

  if (pipeline->filterbank)
  {
    minimum_samples = pipeline->filterbank->get_minimum_samples ();
    if (report_vitals)
    {
      std::cerr << app << ": " << config->filterbank.get_nchan() << " channel ";

      if ((config->coherent_dedispersion || config->coherent_derotation) &&
          config->filterbank.get_convolve_when() == Filterbank::Config::During)
        std::cerr << "convolving ";
      else if (pipeline->filterbank->get_freq_res() > 1)
        std::cerr << "by " << pipeline->filterbank->get_freq_res() << " back ";

      std::cerr << "filterbank requires " << minimum_samples << " samples" << std::endl;
    }

    // without input buffering, the samples lost by the filterbank become the input overlap
    if (!config->input_buffering)
    {
      input_overlap = pipeline->filterbank->get_minimum_samples_lost ();
      if (Operation::verbose)
        std::cerr << "filterbank loses " << input_overlap << " samples" << std::endl;
    }
  }

  // new (non-overlapping) samples per filterbank step; stays zero unless the
  // pipeline has both a filterbank and a convolution
  uint64_t filterbank_resolution = 0;

  if (pipeline->convolution)
  {
    filterbank_resolution = minimum_samples - input_overlap;

    // rescale the convolution's sample counts by (channels at the convolution input * 2)
    // over (nchan * ndim at the source); presumably converts them to source samples (TODO confirm)
    const Observation* info = pipeline->get_source()->get_info();
    unsigned fb_factor = pipeline->convolution->get_input()->get_nchan() * 2;
    fb_factor /= info->get_nchan() * info->get_ndim();

    minimum_samples = pipeline->convolution->get_minimum_samples () * fb_factor;
    if (report_vitals)
      std::cerr << app << ": convolution requires at least "
                << minimum_samples << " samples" << std::endl;

    if (!config->input_buffering)
    {
      input_overlap = pipeline->convolution->get_minimum_samples_lost () * fb_factor;
      if (Operation::verbose)
        std::cerr << "convolution loses " << input_overlap << " samples" << std::endl;
    }
  }

  if (input_overlap)
  {
    // use the pipeline to configure the default block size
    pipeline->set_block_size(minimum_samples,input_overlap);

    uint64_t block_size = pipeline->get_source()->get_block_size();
    input_overlap = pipeline->get_source()->get_overlap();

    // each convolution step must advance by a positive number of samples;
    // otherwise the stride below is zero (division by zero) or wraps around
    if (minimum_samples <= input_overlap)
      throw Error (InvalidState, app,
                   "convolution minimum samples must exceed the input overlap");

    const uint64_t stride = minimum_samples - input_overlap;

    // whole strides that fit in the proposed block after the overlap
    const uint64_t parts = (block_size > input_overlap)
                           ? (block_size - input_overlap) / stride : 0;

    if (Operation::verbose)
      std::cerr << app << ": block_size=" << block_size << " overlap=" << input_overlap
                << " stride=" << stride << " parts=" << parts << std::endl;

    // 64-bit arithmetic avoids overflowing a 32-bit product
    uint64_t block_resize = parts * stride + input_overlap;
    if (Operation::verbose)
      std::cerr << app << ": block_resize=" << block_resize << std::endl;

    if (filterbank_resolution)
    {
      if (Operation::verbose)
        std::cerr << app << ": filterbank_resolution=" << filterbank_resolution << std::endl;

      // search for a block size that suits both Filterbank and Convolution:
      // the largest multiple of filterbank_resolution below block_size whose
      // non-overlapping part is a whole number of convolution strides
      uint64_t best_npart = 0;
      for (uint64_t trial_block_size = filterbank_resolution;
           trial_block_size < block_size;
           trial_block_size += filterbank_resolution)
      {
        if (Operation::verbose)
          std::cerr << app << ": trial_block_size=" << trial_block_size << std::endl;

        // a block no larger than the overlap holds no new samples; skipping it
        // also prevents unsigned wrap-around in the subtraction below
        if (trial_block_size <= input_overlap)
          continue;

        // exact integer test instead of a floating-point integrality check
        if ((trial_block_size - input_overlap) % stride == 0)
          best_npart = trial_block_size / filterbank_resolution;
      }

      if (best_npart == 0)
        throw Error (InvalidState, app,
                      "could not find an overlapping block size "
                      "for both Filterbank and Convolution");

      // WvS to-do: if filterbank also loses samples, then add nlost here
      block_resize = best_npart * filterbank_resolution;
    }

    if (Operation::verbose)
      std::cerr << app << ": old=" << block_size << " new=" << block_resize << std::endl;

    minimum_samples = block_resize;
  }

  pipeline->set_block_size(minimum_samples,input_overlap);
}

} // namespace dsp

#endif
